{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adding new objects to Nengo\n",
    "\n",
    "It is possible to add new objects\n",
    "to the Nengo reference simulator.\n",
    "This involves several steps and the creation\n",
    "of several objects.\n",
    "In this example, we'll go through these steps\n",
    "in order to add a new neuron type to Nengo:\n",
    "a rectified linear neuron.\n",
    "\n",
    "## Step 1: Create a frontend Neurons subclass\n",
    "\n",
    "The `RectifiedLinear` class is what you will use\n",
    "in model scripts to denote that a particular ensemble\n",
    "should be simulated using a rectified linear neuron\n",
    "instead of one of the existing neuron types (e.g., `LIF`).\n",
    "\n",
    "Normally, these kinds of frontend classes exist\n",
    "in either `nengo/objects.py` or `nengo/neurons.py`.\n",
    "Look at these files for examples of how to make your own.\n",
    "In this case, because we're making a neuron type,\n",
    "we'll use `nengo.neurons.LIF` as an example\n",
    "of how to make `RectifiedLinear`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import nengo\n",
    "from nengo.builder import Builder\n",
    "from nengo.builder.operator import Operator\n",
    "from nengo.utils.ensemble import tuning_curves\n",
    "\n",
    "\n",
    "# Neuron types must subclass `nengo.neurons.NeuronType`\n",
    "class RectifiedLinear(nengo.neurons.NeuronType):\n",
    "    \"\"\"A rectified linear neuron model.\"\"\"\n",
    "\n",
    "    # We don't need any additional parameters here;\n",
    "    # gain and bias are sufficient. But, if we wanted\n",
    "    # more parameters, we could accept them by creating\n",
    "    # an __init__ method.\n",
    "\n",
    "    def gain_bias(self, max_rates, intercepts):\n",
    "        \"\"\"Return gain and bias given maximum firing rate and x-intercept.\"\"\"\n",
    "        gain = max_rates / (1 - intercepts)\n",
    "        bias = -intercepts * gain\n",
    "        return gain, bias\n",
    "\n",
    "    def step_math(self, dt, J, output):\n",
    "        \"\"\"Compute rates in Hz for input current (incl. bias)\"\"\"\n",
    "        output[...] = np.maximum(0., J)"
   ]
  },
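  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the `gain_bias` formulas,\n",
    "the sketch below recomputes them with plain NumPy\n",
    "and confirms that the input current `J = gain * x + bias`\n",
    "is zero at the intercept and equals the maximum rate at `x = 1`.\n",
    "The sample `max_rates` and `intercepts` values are made up\n",
    "for illustration and are not part of Nengo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical sample values, chosen only for illustration\n",
    "max_rates = np.array([80.0, 100.0])\n",
    "intercepts = np.array([-0.5, 0.25])\n",
    "\n",
    "# Same formulas as in RectifiedLinear.gain_bias\n",
    "gain = max_rates / (1 - intercepts)\n",
    "bias = -intercepts * gain\n",
    "\n",
    "# The current at the intercept is zero, so the neuron is silent there\n",
    "assert np.allclose(gain * intercepts + bias, 0.0)\n",
    "# The rectified current at x = 1 equals the requested maximum rate\n",
    "assert np.allclose(np.maximum(0.0, gain * 1.0 + bias), max_rates)"
   ]
  },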
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Create a backend Operator subclass\n",
    "\n",
    "The `Operator` (located in `nengo/builder.py`) defines\n",
    "the function that the reference simulator will execute\n",
    "on every timestep. Most new neuron types and learning rules\n",
    "will require a new `Operator`, unless the function\n",
    "being computed can be done by combining several\n",
    "existing operators.\n",
    "\n",
    "In this case, we will make a new operator\n",
    "that outputs the firing rate of each neuron\n",
    "on every timestep.\n",
    "\n",
    "Note that for neuron types specifically,\n",
    "there is a `SimNeurons` operator that\n",
    "calls `step_math`. However, we will\n",
    "implement a new operator here to demonstrate\n",
    "how to build a simple operator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimRectifiedLinear(Operator):\n",
    "    \"\"\"Set output to the firing rate of a rectified linear neuron model.\"\"\"\n",
    "\n",
    "    def __init__(self, output, J, neurons, tag=None):\n",
    "        super().__init__(tag=tag)\n",
    "        self.neurons = neurons  # The RectifiedLinear instance\n",
    "\n",
    "        # Operators must explicitly tell the simulator what signals\n",
    "        # they read, set, update, and increment\n",
    "        self.reads = [J]\n",
    "        self.updates = [output]\n",
    "        self.sets = []\n",
    "        self.incs = []\n",
    "\n",
    "    @property\n",
    "    def output(self):\n",
    "        \"\"\"Output signal of the ensemble.\"\"\"\n",
    "        return self.updates[0]\n",
    "\n",
    "    @property\n",
    "    def J(self):\n",
    "        return self.reads[0]\n",
    "\n",
    "    # If we needed additional signals that aren't in one of the\n",
    "    # reads, updates, sets, or incs lists, we can initialize them\n",
    "    # by making an `init_signals(self, signals, dt)` method.\n",
    "\n",
    "    def make_step(self, signals, dt, rng):\n",
    "        \"\"\"Return a function that the Simulator will execute on each step.\n",
    "\n",
    "        `signals` contains a dictionary mapping each signal to\n",
    "        an ndarray which can be used in the step function.\n",
    "        `dt` is the simulator timestep (which we don't use).\n",
    "        \"\"\"\n",
    "        J = signals[self.J]\n",
    "        output = signals[self.output]\n",
    "\n",
    "        def step_simrectifiedlinear():\n",
    "            # Gain and bias are already taken into account here,\n",
    "            # so we just need to rectify\n",
    "            output[...] = np.maximum(0, J)\n",
    "\n",
    "        return step_simrectifiedlinear"
   ]
  },
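  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The closure pattern used by `make_step` can be tried out\n",
    "without the simulator. The sketch below is a simplified stand-in,\n",
    "assuming two plain NumPy arrays in place of the simulator's signals\n",
    "(the helper name `make_rectify_step` and the values are hypothetical).\n",
    "It shows why the step function assigns with `output[...] = ...`:\n",
    "writing in place keeps every holder of the array in sync."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Illustrative stand-ins for the arrays the simulator would provide\n",
    "J_demo = np.array([-1.0, 0.5, 2.0])\n",
    "output_demo = np.zeros_like(J_demo)\n",
    "\n",
    "\n",
    "def make_rectify_step(J, output):\n",
    "    \"\"\"Return a step function that closes over the two arrays.\"\"\"\n",
    "\n",
    "    def step():\n",
    "        # Write into the existing array rather than rebinding the name\n",
    "        output[...] = np.maximum(0, J)\n",
    "\n",
    "    return step\n",
    "\n",
    "\n",
    "step = make_rectify_step(J_demo, output_demo)\n",
    "step()\n",
    "assert np.array_equal(output_demo, [0.0, 0.5, 2.0])\n",
    "\n",
    "# Changing the input in place is picked up on the next call\n",
    "J_demo[...] = [3.0, -2.0, 0.0]\n",
    "step()\n",
    "assert np.array_equal(output_demo, [3.0, 0.0, 0.0])"
   ]
  },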
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Create a build function\n",
    "\n",
    "In order for `nengo.builder.Builder`\n",
    "to construct signals and operators\n",
    "for the Simulator to use,\n",
    "you must create and register a build function\n",
    "with `nengo.builder.Builder`.\n",
    "This function should take as arguments\n",
    "a `RectifiedLinear` instance,\n",
    "some other arguments specific to the type,\n",
    "and a `nengo.builder.Model` instance.\n",
    "The function should add the approrpiate\n",
    "signals, operators, and other artifacts\n",
    "to the `Model` instance,\n",
    "and then register the build function\n",
    "with `nengo.builder.Builder`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@Builder.register(RectifiedLinear)\n",
    "def build_rectified_linear(model, neuron_type, neurons):\n",
    "    model.operators.append(\n",
    "        SimRectifiedLinear(\n",
    "            output=model.sig[neurons]['out'],\n",
    "            J=model.sig[neurons]['in'],\n",
    "            neurons=neuron_type))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you can use `RectifiedLinear` like any other neuron type!\n",
    "\n",
    "## Tuning curves\n",
    "\n",
    "We can build a small network just to see the tuning curves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nengo.Network()\n",
    "with model:\n",
    "    encoders = np.tile([[1], [-1]], (4, 1))\n",
    "    intercepts = np.linspace(-0.8, 0.8, 8)\n",
    "    intercepts *= encoders[:, 0]\n",
    "    A = nengo.Ensemble(\n",
    "        8, dimensions=1,\n",
    "        intercepts=intercepts,\n",
    "        neuron_type=RectifiedLinear(),\n",
    "        max_rates=nengo.dists.Uniform(80, 100),\n",
    "        encoders=encoders)\n",
    "with nengo.Simulator(model) as sim:\n",
    "    eval_points, activities = tuning_curves(A, sim)\n",
    "plt.figure()\n",
    "plt.plot(eval_points, activities, lw=2)\n",
    "plt.xlabel(\"Input signal\")\n",
    "plt.ylabel(\"Firing rate (Hz)\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2D Representation example\n",
    "\n",
    "Below is the same model as is made in the 2d_representation example,\n",
    "except now using `RectifiedLinear` neurons insated of `nengo.LIF`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nengo.Network(label='2D Representation', seed=10)\n",
    "with model:\n",
    "    neurons = nengo.Ensemble(100, dimensions=2, neuron_type=RectifiedLinear())\n",
    "    sin = nengo.Node(output=np.sin)\n",
    "    cos = nengo.Node(output=np.cos)\n",
    "    nengo.Connection(sin, neurons[0])\n",
    "    nengo.Connection(cos, neurons[1])\n",
    "    sin_probe = nengo.Probe(sin, 'output')\n",
    "    cos_probe = nengo.Probe(cos, 'output')\n",
    "    neurons_probe = nengo.Probe(neurons, 'decoded_output', synapse=0.01)\n",
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.plot(sim.trange(), sim.data[neurons_probe], label=\"Decoded output\")\n",
    "plt.plot(sim.trange(), sim.data[sin_probe], 'r', label=\"Sine\")\n",
    "plt.plot(sim.trange(), sim.data[cos_probe], 'k', label=\"Cosine\")\n",
    "plt.legend();"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
